{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CPU and GPU Operator Customization with Taichi\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_taichi.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_taichi.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "This functionality is only available for ``brainpylib>=0.2.0``. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## English version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brain dynamics is sparse and event-driven, however, proprietary operators for brain dynamics are not well abstracted and summarized. As a result, we are often faced with the need to customize operators. In this tutorial, we will explore how to customize brain dynamics operators using taichi.\n",
    "\n",
    "Start by importing the relevant Python package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import pytest\n",
    "import platform\n",
    "\n",
    "import taichi as ti\n",
    "\n",
    "bm.set_platform('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic Structure of Custom Operators\n",
    "Taichi uses Python functions and decorators to define custom operators. Here is a basic structure of a custom operator:\n",
    "\n",
    "```python\n",
    "@ti.kernel\n",
    "def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):\n",
    "    # Internal logic of the operator\n",
    "```\n",
    "The @ti.kernel decorator tells Taichi that this is a function that requires special compilation."
   ]
  },
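  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this structure, assuming Taichi is used on its own (outside BrainPy), the hypothetical kernel `add_one` below reads one array argument and writes its result into another. NumPy arrays can be passed directly as `ti.types.ndarray()` arguments, and the outermost loop is parallelized automatically:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import taichi as ti\n",
    "\n",
    "ti.init(arch=ti.cpu)  # initialize Taichi for standalone use\n",
    "\n",
    "@ti.kernel\n",
    "def add_one(x: ti.types.ndarray(), out: ti.types.ndarray()):\n",
    "    # Outermost loop: Taichi runs its iterations in parallel\n",
    "    for i in range(x.shape[0]):\n",
    "        out[i] = x[i] + 1.0\n",
    "\n",
    "x = np.arange(4, dtype=np.float32)\n",
    "out = np.zeros_like(x)\n",
    "add_one(x, out)  # the result is written back into the NumPy array\n",
    "print(out)  # [1. 2. 3. 4.]\n",
    "```"
   ]
  },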
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining Helper Functions\n",
    "When defining complex custom operators, you can use the @ti.func decorator to define helper functions. These functions can be called inside the kernel function:\n",
    "\n",
    "```python\n",
    "@ti.func\n",
    "def helper_func(x: ti.f32) -> ti.f32:\n",
    "    # Auxiliary computation\n",
    "    return x * 2\n",
    "\n",
    "@ti.kernel\n",
    "def my_kernel(arg: ti.types.ndarray()):\n",
    "    for i in ti.ndrange(arg.shape[0]):\n",
    "        arg[i] *= helper_func(arg[i])\n",
    "```"
   ]
  },
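  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The snippet above can be run as follows, assuming standalone Taichi on the CPU. Because `my_kernel` multiplies each element by `helper_func(x) = 2 * x`, every element `x` becomes `2 * x ** 2`, so an input of `[1, 2, 3]` becomes `[2, 8, 18]`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import taichi as ti\n",
    "\n",
    "ti.init(arch=ti.cpu)  # initialize Taichi for standalone use\n",
    "\n",
    "@ti.func\n",
    "def helper_func(x: ti.f32) -> ti.f32:\n",
    "    # Inlined into the calling kernel at compile time\n",
    "    return x * 2\n",
    "\n",
    "@ti.kernel\n",
    "def my_kernel(arg: ti.types.ndarray()):\n",
    "    for i in ti.ndrange(arg.shape[0]):\n",
    "        arg[i] *= helper_func(arg[i])  # arg[i] becomes 2 * arg[i] ** 2\n",
    "\n",
    "arr = np.array([1.0, 2.0, 3.0], dtype=np.float32)\n",
    "my_kernel(arr)  # modifies arr in place\n",
    "print(arr)  # [ 2.  8. 18.]\n",
    "```"
   ]
  },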
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: Custom Event Processing Operator\n",
    "The following example demonstrates how to customize an event processing operator:\n",
    "\n",
    "```python\n",
    "@ti.func\n",
    "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n",
    "    return weight[None]\n",
    "\n",
    "@ti.func\n",
    "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n",
    "    out[index] += weight_val\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1),\n",
    "                  weight: ti.types.ndarray(ndim=0),\n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "    weight_val = get_weight(weight)\n",
    "    num_rows, num_cols = indices.shape\n",
    "    ti.loop_config(serialize=True)\n",
    "    for i in range(num_rows):\n",
    "        if vector[i]:\n",
    "            for j in range(num_cols):\n",
    "                update_output(out, indices[i, j], weight_val)\n",
    "```\n",
    "In the declaration of parameters, the last few parameters need to be output parameters so that Taichi can compile correctly. This operator event_ell_cpu receives indices, vectors, weights, and output arrays, and updates the output arrays according to the provided logic."
   ]
  },
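  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the operator's semantics concrete, here is a plain-Python sketch that follows the same control flow as `event_ell_cpu` (the name `event_ell_reference` is hypothetical, and the weight is passed as a plain scalar for simplicity). It is far slower than the Taichi kernel and is meant only for reading, or for checking results on small inputs:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def event_ell_reference(indices, vector, weight, num_targets):\n",
    "    # Same control flow as event_ell_cpu, written in plain Python (slow)\n",
    "    out = np.zeros(num_targets, dtype=np.float32)\n",
    "    num_rows, num_cols = indices.shape\n",
    "    for i in range(num_rows):\n",
    "        if vector[i]:  # only rows that received an event contribute\n",
    "            for j in range(num_cols):\n",
    "                out[indices[i, j]] += weight\n",
    "    return out\n",
    "\n",
    "indices = np.array([[0, 1], [1, 2], [2, 2]])  # 3 sources with 2 targets each\n",
    "vector = np.array([True, False, True])  # events at sources 0 and 2\n",
    "result = event_ell_reference(indices, vector, 0.5, num_targets=4)\n",
    "# result holds 0.5, 0.5, 1.0, 0.0: target 2 is hit twice by source 2\n",
    "```"
   ]
  },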
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Registering and Using Custom Operators\n",
    "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
    "\n",
    "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n",
    "\n",
    "```python\n",
    "import brainpy.math as bm\n",
    "\n",
    "# Taichi operator registration\n",
    "prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
    "\n",
    "# Using the operator\n",
    "def test_taichi_op():\n",
    "    # Create input data\n",
    "    # ...\n",
    "\n",
    "    # Call the custom operator\n",
    "    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "    # Output the result\n",
    "    print(out)\n",
    "```"
   ]
  },
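  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`jax.ShapeDtypeStruct` only describes an output (its shape and dtype); it does not allocate memory. As a sketch, a hypothetical helper that builds the `outs` argument for an operator with a single `float32` output of length `num_targets` could look like this:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def make_outs(num_targets):\n",
    "    # Hypothetical helper: one float32 output vector with one entry per target\n",
    "    return [jax.ShapeDtypeStruct((num_targets,), dtype=jnp.float32)]\n",
    "\n",
    "outs = make_outs(1000)\n",
    "print(outs[0].shape, outs[0].dtype)  # (1000,) float32\n",
    "```\n",
    "\n",
    "With such a helper, the call above could be written as `prim(indices, vector, weight, outs=make_outs(s))`."
   ]
  },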
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Taichi Optimization Methods\n",
    "#### For Loop Decorators\n",
    "Taichi kernels automatically parallelize for-loops in the outermost scope. Our compiler sets the settings automatically to best explore the target architecture. Nonetheless, for Ninjas seeking the final few percent of speed, we provide several APIs to allow developers to fine-tune their programs. Specifying a proper block_dim is key.\n",
    "\n",
    "You can use `ti.loop_config` to set the loop directives for the next for loop. Available directives are:\n",
    "\n",
    "* **parallelize**: Sets the number of threads to use on CPU\n",
    "* **block_dim**: Sets the number of threads in a block on GPU\n",
    "* **serialize**: If you set **serialize** to `True`, the for loop will run serially, and you can write break statements inside it (Only applies on range/ndrange fors). Equals to setting **parallelize** to 1.\n",
    "\n",
    "```python\n",
    "@ti.kernel\n",
    "def break_in_serial_for() -> ti.i32:\n",
    "    a = 0\n",
    "    ti.loop_config(serialize=True)\n",
    "    for i in range(100):  # This loop runs serially\n",
    "        a += i\n",
    "        if i == 10:\n",
    "            break\n",
    "    return a\n",
    "\n",
    "break_in_serial_for()  # returns 55\n",
    "n = 128\n",
    "val = ti.field(ti.i32, shape=n)\n",
    "@ti.kernel\n",
    "def fill():\n",
    "    ti.loop_config(parallelize=8, block_dim=16)\n",
    "    # If the kernel is run on the CPU backend, 8 threads will be used to run it\n",
    "    # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
    "    for i in range(n):\n",
    "        val[i] = i\n",
    "```"
   ]
  },
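  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides allowing `break`, `serialize=True` is also the right choice when each iteration depends on the previous one, because such a loop-carried dependency (for example, a running sum) cannot safely run in parallel. A minimal sketch, assuming standalone Taichi on the CPU (the kernel name `cumulative_sum` is hypothetical):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import taichi as ti\n",
    "\n",
    "ti.init(arch=ti.cpu)  # initialize Taichi for standalone use\n",
    "\n",
    "@ti.kernel\n",
    "def cumulative_sum(x: ti.types.ndarray(), out: ti.types.ndarray()):\n",
    "    acc = 0.0\n",
    "    ti.loop_config(serialize=True)  # iteration i needs the result of iteration i - 1\n",
    "    for i in range(x.shape[0]):\n",
    "        acc += x[i]\n",
    "        out[i] = acc\n",
    "\n",
    "x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n",
    "out = np.zeros_like(x)\n",
    "cumulative_sum(x, out)\n",
    "# out now holds 1, 3, 6, 10\n",
    "```"
   ]
  },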
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete example\n",
    "Here is a complete example showing how to implement a simple operator using the taichi custom operator:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import taichi as ti\n",
    "import pytest\n",
    "import platform\n",
    "\n",
    "import brainpy.math as bm\n",
    "\n",
    "bm.set_platform('cpu')\n",
    "\n",
    "@ti.func\n",
    "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n",
    "  return weight[None]\n",
    "\n",
    "\n",
    "@ti.func\n",
    "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n",
    "  out[index] += weight_val\n",
    "\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1),\n",
    "                  weight: ti.types.ndarray(ndim=0),\n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "  weight_val = get_weight(weight)\n",
    "  num_rows, num_cols = indices.shape\n",
    "  ti.loop_config(serialize=True)\n",
    "  for i in range(num_rows):\n",
    "    if vector[i]:\n",
    "      for j in range(num_cols):\n",
    "        update_output(out, indices[i, j], weight_val)\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_gpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1), \n",
    "                  weight: ti.types.ndarray(ndim=0), \n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "  weight_val = get_weight(weight)\n",
    "  num_rows, num_cols = indices.shape\n",
    "  for i in range(num_rows):\n",
    "    if vector[i]:\n",
    "      for j in range(num_cols):\n",
    "        update_output(out, indices[i, j], weight_val)\n",
    "\n",
    "prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
    "\n",
    "\n",
    "def test_taichi_op_register():\n",
    "  s = 1000\n",
    "  indices = bm.random.randint(0, s, (s, 1000))\n",
    "  vector = bm.random.rand(s) < 0.1\n",
    "\n",
    "  out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "  out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "  print(out)\n",
    "\n",
    "test_taichi_op_register()\n",
    "```"
   ]
  },
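  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When developing a custom operator, it helps to compare its output with a straightforward implementation. As one possible check, here is a pure-JAX sketch of the same computation (the function `event_ell_jax` is hypothetical and not part of BrainPy). It processes every row instead of skipping inactive ones, so it serves as a correctness baseline rather than an efficient event-driven operator:\n",
    "\n",
    "```python\n",
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@partial(jax.jit, static_argnums=3)  # num_targets fixes the output shape, so it must be static\n",
    "def event_ell_jax(indices, vector, weight, num_targets):\n",
    "    # Contribution of each (row, column) entry: weight if the row has an event, else 0\n",
    "    contrib = jnp.where(vector[:, None], weight, 0.0)\n",
    "    contrib = jnp.broadcast_to(contrib, indices.shape)\n",
    "    # Scatter-add: repeated target indices accumulate correctly\n",
    "    out = jnp.zeros((num_targets,), dtype=jnp.float32)\n",
    "    return out.at[indices.ravel()].add(contrib.ravel())\n",
    "\n",
    "indices = jnp.array([[0, 1], [1, 2], [2, 2]])\n",
    "vector = jnp.array([True, False, True])\n",
    "result = event_ell_jax(indices, vector, 0.5, 4)\n",
    "# result holds 0.5, 0.5, 1.0, 0.0\n",
    "```\n",
    "\n",
    "Comparing such a baseline with the custom operator's output (for example with `jnp.allclose`) on small random inputs is a quick way to catch indexing bugs in a kernel."
   ]
  },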
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### More Examples\n",
    "For more examples, please refer to: \n",
    "- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec.py)\n",
    "- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv.py)\n",
    "- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec.py)\n",
    "- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clean the cache of taichi kernels\n",
    "Because brainpy fuse taichi and JAX using taichi AOT method, the taichi kernels will be cached in the system. If you want to clean the cache, you can use the following code:\n",
    "\n",
    "```python\n",
    "import brainpy.math as bm\n",
    "\n",
    "bm.clean_caches()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 中文版"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "大脑动力学具有稀疏和事件驱动的特性，然而，大脑动力学的专有算子并没有很好的抽象和总结。因此，我们往往面临着自定义算子的需求。在这个教程中，我们将探索如何使用Numba来自定义脑动力学算子。\n",
    "\n",
    "首先引入相关的Python包。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import pytest\n",
    "import platform\n",
    "\n",
    "import taichi as ti\n",
    "\n",
    "bm.set_platform('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 自定义算子的基本结构\n",
    "taichi 使用 Python 函数和装饰器来定义自定义算子。以下是一个基本的自定义算子结构：\n",
    "\n",
    "```python\n",
    "@ti.kernel\n",
    "def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):\n",
    "    # 算子内部的计算逻辑\n",
    "```\n",
    "其中，@ti.kernel 装饰器用于告诉 Taichi 这是一个需要特殊编译的函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义辅助函数\n",
    "在定义复杂的自定义算子时，可以使用 @ti.func 装饰器定义辅助函数。这些函数可以在 kernel 函数内部调用：\n",
    "\n",
    "```python\n",
    "@ti.func\n",
    "def helper_func(x: ti.f32) -> ti.f32:\n",
    "    # 辅助计算\n",
    "    return x * 2\n",
    "\n",
    "@ti.kernel\n",
    "def my_kernel(arg: ti.types.ndarray()):\n",
    "    for i in ti.ndrange(arg.shape[0]):\n",
    "        arg[i] *= helper_func(arg[i])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 示例：自定义事件处理算子\n",
    "下面的例子展示了如何自定义一个处理事件的算子：\n",
    "\n",
    "```python\n",
    "@ti.func\n",
    "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n",
    "    return weight[0]\n",
    "\n",
    "@ti.func\n",
    "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n",
    "    out[index] += weight_val\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1),\n",
    "                  weight: ti.types.ndarray(ndim=1),\n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "    weight_val = get_weight(weight)\n",
    "    num_rows, num_cols = indices.shape\n",
    "    ti.loop_config(serialize=True)\n",
    "    for i in range(num_rows):\n",
    "        if vector[i]:\n",
    "            for j in range(num_cols):\n",
    "                update_output(out, indices[i, j], weight_val)\n",
    "```\n",
    "在参数的声明上，需要最后的几个参数是输出参数，这样 Taichi 才能正确的编译。这个算子 event_ell_cpu 接收索引、向量、权重和输出数组，并根据提供的逻辑更新输出数组。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 注册并使用自定义算子\n",
    "在定义了自定义算子之后，可以将其注册到特定框架中，并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`，这样算子就可以在不同的设备上运行。并在调用中指定`outs`参数，用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n",
    "\n",
    "注意： 在算子声明的参数与调用时需要保持顺序的一致。\n",
    "\n",
    "\n",
    "```python\n",
    "import brainpy.math as bm\n",
    "\n",
    "# Taichi 算子注册\n",
    "prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
    "\n",
    "# 算子使用\n",
    "def test_taichi_op():\n",
    "    # 创建输入数据\n",
    "    # ...\n",
    "\n",
    "    # 调用自定义算子\n",
    "    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "    # 输出结果\n",
    "    print(out)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### taichi优化方法\n",
    "\n",
    "#### for循环装饰器\n",
    "Taichi 内核会自动并行化最外层作用域中的 for 循环。我们的编译器会自动设置配置，以最佳方式探索目标架构。然而，对于追求最后几个百分点速度的高手，我们提供了几个 API 来允许开发者精细调整他们的程序。指定合适的 `block_dim` 是关键。\n",
    "\n",
    "你可以使用 `ti.loop_config` 来设置下一个 for 循环的循环指令。可用的指令有：\n",
    "\n",
    "* **parallelize**：在 CPU 上使用的线程数\n",
    "* **block_dim**：在 GPU 上一个块中的线程数\n",
    "* **serialize**：如果你将 **serialize** 设置为 `True`，for 循环将会串行执行，你可以在其中编写 break 语句（仅适用于 range/ndrange 循环）。等同于将 **parallelize** 设置为 1。\n",
    "\n",
    "```python\n",
    "@ti.kernel\n",
    "def break_in_serial_for() -> ti.i32:\n",
    "    a = 0\n",
    "    ti.loop_config(serialize=True)\n",
    "    for i in range(100):  # This loop runs serially\n",
    "        a += i\n",
    "        if i == 10:\n",
    "            break\n",
    "    return a\n",
    "\n",
    "break_in_serial_for()  # returns 55\n",
    "n = 128\n",
    "val = ti.field(ti.i32, shape=n)\n",
    "@ti.kernel\n",
    "def fill():\n",
    "    ti.loop_config(parallelize=8, block_dim=16)\n",
    "    # If the kernel is run on the CPU backend, 8 threads will be used to run it\n",
    "    # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
    "    for i in range(n):\n",
    "        val[i] = i\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 完整示例\n",
    "下面是一个完整的示例，展示了如何使用 taichi 自定义算子来实现一个简单的算子：\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import taichi as ti\n",
    "import pytest\n",
    "import platform\n",
    "\n",
    "import brainpy.math as bm\n",
    "\n",
    "bm.set_platform('cpu')\n",
    "\n",
    "@ti.func\n",
    "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n",
    "  return weight[0]\n",
    "\n",
    "\n",
    "@ti.func\n",
    "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n",
    "  out[index] += weight_val\n",
    "\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1),\n",
    "                  weight: ti.types.ndarray(ndim=1),\n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "  weight_val = get_weight(weight)\n",
    "  num_rows, num_cols = indices.shape\n",
    "  ti.loop_config(serialize=True)\n",
    "  for i in range(num_rows):\n",
    "    if vector[i]:\n",
    "      for j in range(num_cols):\n",
    "        update_output(out, indices[i, j], weight_val)\n",
    "\n",
    "@ti.kernel\n",
    "def event_ell_gpu(indices: ti.types.ndarray(ndim=2),\n",
    "                  vector: ti.types.ndarray(ndim=1), \n",
    "                  weight: ti.types.ndarray(ndim=1), \n",
    "                  out: ti.types.ndarray(ndim=1)):\n",
    "  weight_0 = weight[0]\n",
    "  ti.loop_config(block_dim=64)\n",
    "  for ij in ti.grouped(indices):\n",
    "      if vector[ij[0]]:\n",
    "          out[ij[1]] += weight_0\n",
    "\n",
    "prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
    "\n",
    "\n",
    "def test_taichi_op_register():\n",
    "  s = 1000\n",
    "  indices = bm.random.randint(0, s, (s, 1000))\n",
    "  vector = bm.random.rand(s) < 0.1\n",
    "  weight = bm.array([1.0])\n",
    "\n",
    "  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n",
    "\n",
    "  print(out)\n",
    "\n",
    "test_taichi_op_register()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 更多示例\n",
    "对于更多示例, 请参考: \n",
    "- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec.py)\n",
    "- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv.py)\n",
    "- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec.py)\n",
    "- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 清除Taichi kernel的缓存\n",
    "因为brainpy使用taichi的AOT方法来融合taichi和JAX，所以taichi的kernel会被缓存到系统中。如果你想清除缓存，可以使用以下代码：\n",
    "\n",
    "```python\n",
    "import brainpy.math as bm\n",
    "\n",
    "bm.clean_caches()\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
